{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/9_compare_baselines.ipynb)\n",
    "# Reliably compare algorithm performance\n",
    "\n",
    "Did we actually match the expert performance or was it just luck? Did this hyperparameter change actually improve the performance of our algorithm? These are questions that we need to answer when we want to compare the performance of different algorithms or hyperparameters.\n",
    "\n",
    "`imitation` provides some tools to help you answer these questions. For demonstration purposes, we will use Behavior Cloning on the CartPole-v1 environment. We will compare different variants of the trained algorithm, and also compare it with a more sophisticated algorithm, DAgger."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will start by training a good (but not perfect) expert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "from stable_baselines3 import PPO\n",
    "from stable_baselines3.ppo import MlpPolicy\n",
    "\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "expert.learn(10_000)  # set to 100_000 for better performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For comparison, let's also train a not-quite-expert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "not_expert = PPO(\n",
    "    policy=MlpPolicy,\n",
    "    env=env,\n",
    "    seed=0,\n",
    "    batch_size=64,\n",
    "    ent_coef=0.0,\n",
    "    learning_rate=0.0003,\n",
    "    n_epochs=10,\n",
    "    n_steps=64,\n",
    ")\n",
    "\n",
    "not_expert.learn(1_000)  # set to 10_000 for slightly better performance"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So are they any good? Let's quickly get a point estimate of their performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "env.reset(seed=0)\n",
    "\n",
    "expert_reward, _ = evaluate_policy(expert, env, 1)\n",
    "not_expert_reward, _ = evaluate_policy(not_expert, env, 1)\n",
    "\n",
    "print(f\"Expert reward: {expert_reward:.2f}\")\n",
    "print(f\"Not expert reward: {not_expert_reward:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But wait! We only ran the evaluation once. What if we got lucky? Let's run the evaluation a few more times and see what happens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expert_reward, _ = evaluate_policy(expert, env, 10)\n",
    "not_expert_reward, _ = evaluate_policy(not_expert, env, 10)\n",
    "\n",
    "print(f\"Expert reward: {expert_reward:.2f}\")\n",
    "print(f\"Not expert reward: {not_expert_reward:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seems a bit more robust now, but how certain are we? Fortunately, `imitation` provides us with tools to answer this.\n",
    "\n",
    "We will perform a permutation test using the `is_significant_reward_improvement` function. We want to be very certain -- let's set the bar high and require a p-value of 0.001."
   ]
  },
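  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of what a one-sided permutation test computes (assuming the statistic is the difference in mean episode reward and the permutations are sampled at random; the details of `imitation`'s implementation may differ), let $x_1, \\dots, x_n$ be the baseline rewards (the first argument) and $y_1, \\dots, y_m$ the candidate rewards (the second argument):\n",
    "\n",
    "$$T_{\\text{obs}} = \\frac{1}{m} \\sum_{j=1}^{m} y_j - \\frac{1}{n} \\sum_{i=1}^{n} x_i$$\n",
    "\n",
    "Pool all $n + m$ rewards, randomly reassign them to groups of size $n$ and $m$, and recompute the statistic. Repeating this $B$ times gives $T_1, \\dots, T_B$, and the one-sided p-value is estimated as\n",
    "\n",
    "$$\\hat{p} = \\frac{1 + \\left|\\{ b : T_b \\geq T_{\\text{obs}} \\}\\right|}{1 + B}$$\n",
    "\n",
    "The improvement is called significant when $\\hat{p}$ falls below the chosen threshold (0.001 here)."
   ]
  },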
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.testing.reward_improvement import is_significant_reward_improvement\n",
    "\n",
    "expert_rewards, _ = evaluate_policy(expert, env, 10, return_episode_rewards=True)\n",
    "not_expert_rewards, _ = evaluate_policy(\n",
    "    not_expert, env, 10, return_episode_rewards=True\n",
    ")\n",
    "\n",
    "significant = is_significant_reward_improvement(\n",
    "    not_expert_rewards, expert_rewards, 0.001\n",
    ")\n",
    "\n",
    "print(\n",
    "    f\"The expert is {'NOT ' if not significant else ''}significantly better than the not-expert.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Huh, turns out we set the bar too high. We could lower our standards, but that's for cowards.\n",
    "Instead, we can collect more data and try again."
   ]
  },
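  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does more data help? As a back-of-the-envelope argument (assuming independent episodes whose rewards have standard deviation $\\sigma$), the standard error of the mean reward $\\bar{r}_n$ over $n$ episodes is\n",
    "\n",
    "$$\\mathrm{SE}(\\bar{r}_n) = \\frac{\\sigma}{\\sqrt{n}}$$\n",
    "\n",
    "Going from 10 to 100 episodes therefore shrinks the noise in each mean by a factor of $\\sqrt{100 / 10} = \\sqrt{10} \\approx 3.2$, which makes a real difference between the two policies easier to detect."
   ]
  },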
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.testing.reward_improvement import is_significant_reward_improvement\n",
    "\n",
    "expert_rewards, _ = evaluate_policy(expert, env, 100, return_episode_rewards=True)\n",
    "not_expert_rewards, _ = evaluate_policy(\n",
    "    not_expert, env, 100, return_episode_rewards=True\n",
    ")\n",
    "\n",
    "significant = is_significant_reward_improvement(\n",
    "    not_expert_rewards, expert_rewards, 0.001\n",
    ")\n",
    "\n",
    "print(\n",
    "    f\"The expert is {'NOT ' if not significant else ''}significantly better than the not-expert.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we go! We can now be 99.9% confident that the expert is better than the not-expert -- in this specific case, with these specific trained models. It might still be an extraordinary stroke of luck, or a conspiracy to make us choose the wrong algorithm, but outside of that, we can be pretty sure our data's correct."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the same principle to with imitation learning algorithms. Let's train a behavior cloning algorithm and see how it compares to the expert. This time, we can lower the bar to the standard \"scientific\" threshold of 0.05."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Like in the first tutorial, we will start by collecting some expert data. But to spice it up, let's also get some data from the not-quite-expert."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.data import rollout\n",
    "from imitation.data.wrappers import RolloutInfoWrapper\n",
    "from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng()\n",
    "expert_rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
    "    rng=rng,\n",
    ")\n",
    "expert_transitions = rollout.flatten_trajectories(expert_rollouts)\n",
    "\n",
    "\n",
    "not_expert_rollouts = rollout.rollout(\n",
    "    not_expert,\n",
    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=50),\n",
    "    rng=rng,\n",
    ")\n",
    "not_expert_transitions = rollout.flatten_trajectories(not_expert_rollouts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try cloning an expert and a non-expert, and see how they compare."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms import bc\n",
    "\n",
    "expert_bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=expert_transitions,\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "not_expert_bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=not_expert_transitions,\n",
    "    rng=rng,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "expert_bc_trainer.train(n_epochs=2)\n",
    "not_expert_bc_trainer.train(n_epochs=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_expert_rewards, _ = evaluate_policy(\n",
    "    expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n",
    ")\n",
    "bc_not_expert_rewards, _ = evaluate_policy(\n",
    "    not_expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n",
    ")\n",
    "significant = is_significant_reward_improvement(\n",
    "    bc_not_expert_rewards, bc_expert_rewards, 0.05\n",
    ")\n",
    "print(f\"Cloned expert rewards: {bc_expert_rewards}\")\n",
    "print(f\"Cloned not-expert rewards: {bc_not_expert_rewards}\")\n",
    "\n",
    "print(\n",
    "    f\"Cloned expert is {'NOT ' if not significant else ''}significantly better than the cloned not-expert.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How about comparing the expert clone to the expert itself?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_clone_rewards, _ = evaluate_policy(\n",
    "    expert_bc_trainer.policy, env, 10, return_episode_rewards=True\n",
    ")\n",
    "\n",
    "expert_rewards, _ = evaluate_policy(expert, env, 10, return_episode_rewards=True)\n",
    "\n",
    "significant = is_significant_reward_improvement(bc_clone_rewards, expert_rewards, 0.05)\n",
    "\n",
    "print(f\"Cloned expert rewards: {bc_clone_rewards}\")\n",
    "print(f\"Expert rewards: {expert_rewards}\")\n",
    "\n",
    "print(\n",
    "    f\"Expert is {'NOT ' if not significant else ''}significantly better than the cloned expert.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Turns out the expert is significantly better than the clone -- again, in this case. Note, however, that this is not proof that the clone is as good as the expert -- there's a subtle difference between the two claims in the context of hypothesis testing.\n",
    "\n",
    "Note: if you changed the duration of the training at the beginning of this tutorial, you might get different results. While this might break the narrative in this tutorial, it's a good learning opportunity.\n",
    "\n",
    "When comparing the performance of two agents, algorithms, hyperparameter sets, always remember the scope of what you're testing. In this tutorial, we have one instance of an expert -- but RL training is famously unstable, so another training run with another random seed would likely produce a slightly different result. So ideally, we would like to repeat this procedure several times, training the same agent with different random seeds, and then compare the average performance of the two agents.\n",
    "\n",
    "Even then, this is just on one environment, with one algorithm. So be wary of generalizing your results too much."
   ]
  },
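  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the hypothesis-testing caveat above concrete, here is a rough sketch. The symbols $\\mu_{\\text{expert}}$ and $\\mu_{\\text{clone}}$ for the true mean episode rewards are notation introduced here for illustration, and the null hypothesis is stated loosely. The one-sided test compares\n",
    "\n",
    "$$H_0: \\mu_{\\text{expert}} \\leq \\mu_{\\text{clone}} \\qquad \\text{against} \\qquad H_1: \\mu_{\\text{expert}} > \\mu_{\\text{clone}}$$\n",
    "\n",
    "A p-value below the threshold lets us reject $H_0$ in favor of $H_1$. A p-value above it only means that these samples did not provide enough evidence against $H_0$; it does not show that $H_0$ is true."
   ]
  },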
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also use the same method to compare different algorithms. While CartPole is pretty easy, we can make it more difficult by decreasing the number of episodes in our dataset, and generating them with a suboptimal policy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rollouts = rollout.rollout(\n",
    "    expert,\n",
    "    DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
    "    rollout.make_sample_until(min_timesteps=None, min_episodes=1),\n",
    "    rng=rng,\n",
    ")\n",
    "transitions = rollout.flatten_trajectories(rollouts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try training a behavior cloning algorithm on this dataset.\n",
    "\n",
    "Note that for DAgger, we have to cheat a little bit -- it's allowed to use the expert policy to generate additional data.\n",
    "For the purposes of this tutorial, we'll stick with this to avoid spending hours training an expert for a more complex environment.\n",
    "\n",
    "So while this little experiment isn't definitive proof that DAgger is better than BC, you can use the same method to compare any two algorithms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.algorithms.dagger import SimpleDAggerTrainer\n",
    "import tempfile\n",
    "\n",
    "bc_trainer = bc.BC(\n",
    "    observation_space=env.observation_space,\n",
    "    action_space=env.action_space,\n",
    "    demonstrations=transitions,\n",
    "    rng=rng,\n",
    ")\n",
    "\n",
    "bc_trainer.train(n_epochs=1)\n",
    "\n",
    "\n",
    "with tempfile.TemporaryDirectory(prefix=\"dagger_example_\") as tmpdir:\n",
    "    print(tmpdir)\n",
    "    dagger_bc_trainer = bc.BC(\n",
    "        observation_space=env.observation_space,\n",
    "        action_space=env.action_space,\n",
    "        rng=np.random.default_rng(),\n",
    "    )\n",
    "    dagger_trainer = SimpleDAggerTrainer(\n",
    "        venv=DummyVecEnv([lambda: RolloutInfoWrapper(env)]),\n",
    "        scratch_dir=tmpdir,\n",
    "        expert_policy=expert,\n",
    "        bc_trainer=dagger_bc_trainer,\n",
    "        rng=np.random.default_rng(),\n",
    "    )\n",
    "\n",
    "    dagger_trainer.train(5000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training both BC and DAgger, let's compare their performances again! We expect DAgger to be better -- after all, it's a more advanced algorithm. But is it significantly better?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bc_rewards, _ = evaluate_policy(bc_trainer.policy, env, 10, return_episode_rewards=True)\n",
    "dagger_rewards, _ = evaluate_policy(\n",
    "    dagger_trainer.policy, env, 10, return_episode_rewards=True\n",
    ")\n",
    "\n",
    "significant = is_significant_reward_improvement(bc_rewards, dagger_rewards, 0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"BC rewards: {bc_rewards}\")\n",
    "print(f\"DAgger rewards: {dagger_rewards}\")\n",
    "\n",
    "print(\n",
    "    f\"Our DAgger agent is {'NOT ' if not significant else ''}significantly better than BC.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you increased the number of training iterations for the expert (in the first cell of the tutorial), you should see that DAgger indeed performs better than BC. If you didn't, you likely see the opposite result. Yet another reason to be careful when interpreting results!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's take a moment, to remember the limitations of this experiment. We're comparing two algorithms on one environment, with one dataset. We're also using a suboptimal expert policy, which might not be the best choice for BC. If you want to convince yourself that DAgger is better than BC, you should pick out a more complex environment, you should run this experiment several times, with different random seeds and perform some hyperparameter optimization to make sure we're not just using unlucky hyperparameters. At the end, we would also need to run the same hypothesis test across average returns of several independent runs.\n",
    "\n",
    "But now you have all the pieces of the puzzle to do that!"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "bd378ce8f53beae712f05342da42c6a7612fc68b19bea03b52c7b1cdc8851b5f"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
